#include "allocate.h"
#include "analyze.h"
#include "internal.h"
#include "logging.h"
#include "print.h"
#include "proprobe.h"
#include "report.h"
#include "terminate.h"
#include "trail.h"
#include "transitive.h"

#include <stddef.h>

static void
transitive_assign(kissat *solver, unsigned lit)
{
  LOG("transitive assign %s", LOGLIT(lit));
  value *values = solver->values;
  const unsigned not_lit = NOT(lit);
  assert(!values[lit]);
  assert(!values[not_lit]);
  values[lit] = 1;
  values[not_lit] = -1;
  PUSH_STACK(solver->trail, lit);
}

static void
transitive_backtrack(kissat *solver, unsigned saved)
{
  assert(saved <= SIZE_STACK(solver->trail));
  value *values = solver->values;
  while (SIZE_STACK(solver->trail) > saved)
  {
    const unsigned lit = POP_STACK(solver->trail);
    LOG("transitive unassign %s", LOGLIT(lit));
    const unsigned not_lit = NOT(lit);
    assert(values[lit] > 0);
    assert(values[not_lit] < 0);
    values[lit] = values[not_lit] = 0;
  }
  solver->propagated = saved;
  solver->level = 0;
}

static void
prioritize_binaries(kissat *solver)
{
  assert(solver->watching);
  statches large;
  INIT_STACK(large);
  watches *all_watches = solver->watches;
  for (all_literals(lit))
  {
    assert(EMPTY_STACK(large));
    watches *watches = all_watches + lit;
    watch *begin_watches = BEGIN_WATCHES(*watches), *q = begin_watches;
    const watch *end_watches = END_WATCHES(*watches), *p = q;
    while (p != end_watches)
    {
      const watch head = *q++ = *p++;
      if (head.type.binary)
        continue;
      const watch tail = *p++;
      PUSH_STACK(large, head);
      PUSH_STACK(large, tail);
      q--;
    }
    const watch *end_large = END_STACK(large);
    const watch *r = BEGIN_STACK(large);
    while (r != end_large)
      *q++ = *r++;
    assert(q == end_watches);
    CLEAR_STACK(large);
  }
  RELEASE_STACK(large);
}

static bool
transitive_reduce(kissat *solver,
                  unsigned src, uint64_t limit,
                  uint64_t *reduced_ptr, unsigned *units)
{
  bool res = false;
  assert(!VALUE(src));
  LOG("transitive reduce %s", LOGLIT(src));
  watches *all_watches = solver->watches;
  watches *src_watches = all_watches + src;
  watch *end_src = END_WATCHES(*src_watches);
  watch *begin_src = BEGIN_WATCHES(*src_watches);
  unsigned ticks = kissat_inc_cache_lines(src_watches->size, sizeof(watch));
  ADD(transitive_ticks, ticks + 1);
  solver->dps_ticks += 1 + ticks;
  const unsigned not_src = NOT(src);
  unsigned reduced = 0;
  bool failed = false;
  for (watch *p = begin_src; p != end_src; p++)
  {
    const watch src_watch = *p;
    if (!src_watch.type.binary)
      break;
    const unsigned dst = src_watch.binary.lit;
    if (dst < src)
      continue;
    if (VALUE(dst))
      continue;
    assert(solver->propagated == SIZE_STACK(solver->trail));
    unsigned saved = solver->propagated;
    assert(!solver->level);
    solver->level = 1;
    transitive_assign(solver, not_src);
    const bool redundant = src_watch.binary.redundant;
    bool transitive = false;
    unsigned propagated = 0;
    while (!transitive && !failed &&
           solver->propagated < SIZE_STACK(solver->trail))
    {
      const unsigned lit = PEEK_STACK(solver->trail, solver->propagated);
      solver->propagated++;
      propagated++;
      LOG("transitive propagate %s", LOGLIT(lit));
      assert(VALUE(lit) > 0);
      const unsigned not_lit = NOT(lit);
      watches *lit_watches = all_watches + not_lit;
      const watch *end_lit = END_WATCHES(*lit_watches);
      const watch *begin_lit = BEGIN_WATCHES(*lit_watches);
      ticks = kissat_inc_cache_lines(lit_watches->size, sizeof(watch));
      ADD(transitive_ticks, ticks + 1);
      solver->dps_ticks += ticks + 1;
      for (const watch *q = begin_lit; q != end_lit; q++)
      {
        if (p == q)
          continue;
        const watch lit_watch = *q;
        if (!lit_watch.type.binary)
          break;
        if (not_lit == src && lit_watch.binary.lit == ILLEGAL_LIT)
          continue;
        if (!redundant && lit_watch.binary.redundant)
          continue;
        const unsigned other = lit_watch.binary.lit;
        if (other == dst)
        {
          transitive = true;
          break;
        }
        const value value = VALUE(other);
        if (value < 0)
        {
          LOG("both %s and %s reachable from %s",
              LOGLIT(NOT(other)), LOGLIT(other), LOGLIT(src));
          failed = true;
          break;
        }
        if (!value)
          transitive_assign(solver, other);
      }
    }

    assert(solver->probing);
    ADD(propagations, propagated);
    ADD(probing_propagations, propagated);
    ADD(transitive_propagations, propagated);
    transitive_backtrack(solver, saved);

    if (transitive)
    {
      LOGBINARY(src, dst, "transitive reduce");
      INC(transitive_reduced);
      watches *dst_watches = all_watches + dst;
      watch dst_watch = src_watch;
      assert(dst_watch.binary.lit == dst);
      assert(dst_watch.binary.redundant == redundant);
      dst_watch.binary.lit = src;
      REMOVE_WATCHES(*dst_watches, dst_watch);
      kissat_inc_delete_binary(solver,
                           src_watch.binary.redundant,
                           src_watch.binary.hyper, src, dst);
      p->binary.lit = ILLEGAL_LIT;
      reduced++;
      res = true;
    }

    if (failed)
      break;
    if (solver->statistics.transitive_ticks > limit)
      break;
    if (TERMINATED(16))
      break;
  }
  if (solver->dps == 1 && solver->dps_ticks >= solver->dps_period)
  {
    solver->dps_ticks -= solver->dps_period;
    solver->cbk_start_new_period(solver->issuer);
  }
  if (reduced)
  {
    *reduced_ptr += reduced;
    assert(begin_src == BEGIN_WATCHES(WATCHES(src)));
    assert(end_src == END_WATCHES(WATCHES(src)));
    watch *q = begin_src;
    for (const watch *p = begin_src; p != end_src; p++)
    {
      const watch src_watch = *q++ = *p;
      if (!src_watch.type.binary)
      {
        *q++ = *++p;
        continue;
      }
      if (src_watch.binary.lit == ILLEGAL_LIT)
        q--;
    }
    assert(end_src - q == (ptrdiff_t)reduced);
    SET_END_OF_WATCHES(*src_watches, q);
  }

  if (failed)
  {
    LOG("transitive failed literal %s", LOGLIT(not_src));
    INC(failed);
    *units += 1;
    res = true;

    kissat_inc_assign_unit(solver, src);
    CHECK_AND_ADD_UNIT(src);
    ADD_UNIT_TO_PROOF(src);

    clause *conflict = kissat_inc_probing_propagate(solver, 0);
    if (conflict)
    {
      (void)kissat_inc_analyze(solver, conflict);
      assert(solver->inconsistent);
    }
    else
    {
      assert(solver->unflushed);
      kissat_inc_flush_trail(solver);
    }
  }

  return res;
}

void kissat_inc_transitive_reduction(kissat *solver)
{
  if (solver->inconsistent)
    return;
  assert(solver->watching);
  assert(solver->probing);
  assert(!solver->level);
  if (!GET_OPTION(transitive))
    return;
  if (TERMINATED(17))
    return;
  START(transitive);
  prioritize_binaries(solver);
  bool success = false;
  uint64_t reduced = 0;
  unsigned units = 0;

  SET_EFFICIENCY_BOUND(limit, transitive, transitive_ticks, search_ticks, 0);

  assert(solver->transitive < LITS);
  const unsigned end = solver->transitive;
#ifndef QUIET
  const unsigned active = solver->active;
#endif
  unsigned probed = 0;
  do
  {
    const unsigned lit = solver->transitive++;
    if (solver->transitive == LITS)
      solver->transitive = 0;
    if (!ACTIVE(IDX(lit)))
      continue;
    probed++;
    if (transitive_reduce(solver, lit, limit, &reduced, &units))
      success = true;
    if (solver->inconsistent)
      break;
    if (solver->statistics.transitive_ticks > limit)
      break;
    if (TERMINATED(18))
      break;
  } while (solver->transitive != end);
  kissat_inc_phase(solver, "transitive", GET(probings),
               "probed %u (%.0f%%): reduced %" PRIu64 ", units %u",
               probed, kissat_inc_percent(probed, 2 * active), reduced, units);
  STOP(transitive);
  REPORT(!success, 't');
#ifdef QUIET
  (void)success;
#endif
}
